import argparse
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, models
from text_pair_dataset import TextPairDataset
from tqdm import tqdm
import os

from heads import get_matching_head
from loss_func import get_loss_func 
import random
import numpy as np

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random generators and pin the cuDNN flags.

    Args:
        seed: Integer passed unchanged to every seeding function below.
    """
    # Each library keeps its own generator, so the same seed is applied to all of them.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # relevant when more than one GPU is in use
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN settings used for reproducible runs: deterministic mode on, benchmark mode off.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def build_model(base_model="sentence-transformers/all-MiniLM-L6-v2"):
    """Assemble a SentenceTransformer from a transformer backbone and a pooling layer.

    Args:
        base_model: Model name or path handed to ``models.Transformer``.

    Returns:
        A ``SentenceTransformer`` whose modules are the transformer followed by a
        ``models.Pooling`` layer built with that module's default settings.
    """
    # NOTE(review): trust_remote_code is enabled everywhere, which presumably allows
    # custom code shipped with the checkpoint to run -- only use trusted models.
    remote_code_opts = {"trust_remote_code": True}

    transformer = models.Transformer(
        base_model,
        model_args=dict(remote_code_opts),
        config_args=dict(remote_code_opts),
    )
    # The pooling layer is sized from the transformer's word-embedding dimension.
    word_dim = transformer.get_word_embedding_dimension()
    pooler = models.Pooling(word_dim)

    return SentenceTransformer(modules=[transformer, pooler], trust_remote_code=True)

def _embed_texts(embedding_model, texts, device):
    """Tokenize ``texts``, move the tensors to ``device`` and return the sentence embeddings.

    The module's forward pass is called directly so the result stays attached to the
    autograd graph whenever the encoder's parameters are trainable.
    """
    tokenized = {key: value.to(device) for key, value in embedding_model.tokenize(texts).items()}
    return embedding_model(tokenized)["sentence_embedding"]


def _save_checkpoint(embedding_model, matching_head, target_dir):
    """Write the embedding model and the matching head's state dict under ``target_dir``."""
    os.makedirs(target_dir, exist_ok=True)
    embedding_model.save(os.path.join(target_dir, "embedding_model"))
    torch.save(matching_head.state_dict(), os.path.join(target_dir, "matching_head.pt"))


def train_and_save(args):
    """Train the embedding model and matching head on text pairs, checkpointing every epoch.

    Training is skipped entirely when both ``<save_dir>/embedding_model`` and
    ``<save_dir>/matching_head.pt`` already exist. Otherwise a checkpoint is written to
    ``<save_dir>/epoch-<n>`` after each epoch and the final weights to ``<save_dir>``.

    Args:
        args: Namespace providing ``save_dir``, ``seed``, ``model_name``, ``head_type``,
            ``freeze_embedding_model``, ``pos_data_path``, ``neg_data_path``, ``limit``,
            ``batch_size``, ``lr``, ``loss_type`` and ``num_epochs``.

    Raises:
        ValueError: If the dataset yields no batches, since there is nothing to train on.
    """
    final_embedding_dir = os.path.join(args.save_dir, "embedding_model")
    final_head_path = os.path.join(args.save_dir, "matching_head.pt")
    if os.path.exists(final_embedding_dir) and os.path.exists(final_head_path):
        print(f"Model already exists at {args.save_dir}. Skipping training.")
        return
    set_seed(args.seed)

    # Use the GPU when one is present; fall back to CPU instead of crashing in .cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    embedding_model = build_model(args.model_name).to(device)
    embedding_dim = embedding_model.get_sentence_embedding_dimension()
    matching_head = get_matching_head(args.head_type, embedding_dim).to(device).train()

    if args.freeze_embedding_model:
        for param in embedding_model.parameters():
            param.requires_grad = False
        # A frozen encoder is a fixed feature extractor: eval mode switches off dropout
        # so the head is not trained on randomly perturbed embeddings.
        embedding_model.eval()
        trainable_params = list(matching_head.parameters())
    else:
        embedding_model.train()
        trainable_params = list(embedding_model.parameters()) + list(matching_head.parameters())

    dataset = TextPairDataset(args.pos_data_path, args.neg_data_path, limit=args.limit)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    num_batches = len(dataloader)
    if num_batches == 0:
        # Fail early with a clear message instead of dividing by zero after a no-op epoch.
        raise ValueError(
            f"No training pairs loaded from {args.pos_data_path} and {args.neg_data_path}."
        )

    optimizer = optim.Adam(trainable_params, lr=args.lr)
    loss_fn = get_loss_func(args.loss_type)

    for epoch in range(args.num_epochs):
        total_loss = 0.0
        for answers, evidences, labels in tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.num_epochs}"):
            labels = labels.float().to(device)

            emb_a = _embed_texts(embedding_model, answers, device)
            emb_b = _embed_texts(embedding_model, evidences, device)

            outputs = matching_head({"embedding_a": emb_a, "embedding_b": emb_b})
            logits = outputs["logits"].squeeze(-1)

            # The loss receives the raw embeddings as well as the head's logits.
            loss = loss_fn(logits, labels, emb_a=emb_a, emb_b=emb_b)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {total_loss/num_batches:.4f}")

        epoch_save_dir = os.path.join(args.save_dir, f"epoch-{epoch+1}")
        _save_checkpoint(embedding_model, matching_head, epoch_save_dir)
        print(f"Models for epoch {epoch+1} saved to {epoch_save_dir}/")

    # The final weights also go to the top-level directory, which is what the
    # skip check at the start of this function looks for.
    _save_checkpoint(embedding_model, matching_head, args.save_dir)
    print(f"Models saved to {args.save_dir}/")

if __name__ == "__main__":
    # Script entry point: declare the command-line options, parse them, and start training.
    parser = argparse.ArgumentParser()

    # (flag, add_argument keyword options), registered in this order.
    option_specs = [
        # Model, output location and data sources.
        ("--model_name", dict(type=str, default="sentence-transformers/all-MiniLM-L6-v2")),
        ("--save_dir", dict(type=str, default="saved_model_dir_1wdata")),
        ("--pos_data_path", dict(type=str, required=True)),
        ("--neg_data_path", dict(type=str, required=True)),
        # Optimisation settings.
        ("--batch_size", dict(type=int, default=128)),
        ("--lr", dict(type=float, default=2e-5)),
        ("--num_epochs", dict(type=int, default=3)),
        ("--limit", dict(type=int, default=0)),
        ("--seed", dict(type=int, default=42, help="Random seed for reproducibility")),
        # Architecture of the matching head and the training objective.
        ("--head_type", dict(
            type=str,
            default="base",
            choices=["base", "deep_mlp", "cos_sim", "residual", "cross_attn", "feature", "cos_sim_deeper"],
        )),
        ("--loss_type", dict(
            type=str,
            default="bce",
            choices=["bce", "focal", "contrastive", "circle", "weighted_bce", "auc_margin"],
        )),
        ("--freeze_embedding_model", dict(
            action="store_true",
            help="If set, freeze the embedding model and only train the matching head.",
        )),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    train_and_save(args)
